# coding=utf-8
# Copyright (c) 2016 Tinydot. inc.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.
import tensorflow as tf
import json
import os
from .JsonParse import _
from math import ceil,floor
from . import get_logger

logger = get_logger(__name__)

def load_jsonfile(url):
    """
    Read a JSON file.

    :param url: path to the JSON file
    :return: the object decoded from the file (a dict for the config files
        this module reads; json.load itself may return any JSON value)
    """
    import io  # io.open accepts ``encoding`` on both Python 2 and Python 3

    # Decode explicitly as UTF-8 (the encoding JSON text is required to use)
    # rather than the locale default, which fails on non-ASCII content under
    # e.g. a C/POSIX or GBK locale.
    with io.open(url, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_pbfile(url, flag=False):
    """
    Load a serialized TensorFlow GraphDef from a .pb file.

    The parsed GraphDef is always imported into a fresh tf.Graph, even when
    only the GraphDef itself is requested.

    :param url: path to the .pb file
    :param flag: when True return the parsed GraphDef, when False return the
        tf.Graph it was imported into
    :return: the GraphDef (flag=True) or the tf.Graph (flag=False)
    """
    with open(url, 'rb') as pb_file:
        serialized = pb_file.read()

    parsed = tf.GraphDef()
    parsed.ParseFromString(serialized)

    imported = tf.Graph()
    with imported.as_default():
        # name='' imports the nodes without a name-scope prefix.
        tf.import_graph_def(parsed, name='')

    return parsed if flag else imported


class JsonOperation(object):
    """
    Parser for a single TensorFlow op that extracts its features.

    Extraction is driven by the JSON configuration entry for the op's type:
    each entry line has the form ``"name = expression"`` and is evaluated in
    order by ``resolve``.
    """

    def __init__(self, op, operation, op_ban_list, feature_dict):
        """
        Store the inputs of a single-op parser.

        :param op: TensorFlow op being analysed (``resolve`` reads its
            ``.type`` and ``.inputs``)
        :param operation: JSON config entry for this op type; ``resolve``
            reads its ``'Var'`` and ``'Feature'`` lists of
            ``"name = expression"`` strings
        :param op_ban_list: op types whose features must not be recorded
        :param feature_dict: dict whose keys are the features to record;
            features not in it are not recorded, and its key order fixes the
            order of the per-op dict appended to ``sequence``
        """
        self.op = op
        self.operation = operation
        self.op_ban_list = op_ban_list
        self.feature_dict = feature_dict

    def resolve(self, result_op, sequence):
        """
        Process one op: evaluate its configured expressions and record features.

        Every ``"name = expression"`` line of ``operation['Var']`` followed by
        ``operation['Feature']`` is evaluated in order. An expression can use
        ``inputs`` (the op's inputs), ``self``, ``result_op``, ``sequence``,
        any name assigned by an earlier line, and this module's globals
        (e.g. ``_``, ``ceil``, ``floor``). Once one expression raises, it and
        every remaining name are set to the string ``'Error'``.

        :param result_op: dict of feature totals accumulated over the ops
            resolved so far; updated in place (a recorded feature's entry is
            created with 0, and its value is added when it is not a string)
        :param sequence: list of per-op feature dicts; unless the op type is
            banned, a dict holding ``'op_type'`` plus one entry per key of
            ``feature_dict`` (0 when not computed or falsy) is appended
        :return: None
        :raises ValueError: if a configuration line contains no ``'='``
        """
        recorded = self.op.type not in self.op_ban_list
        # Explicit namespace for the config expressions. This method used to
        # write into locals(), which only works by accident on CPython <= 3.12
        # and is silently discarded on 3.13+ (PEP 667); it also let config
        # names collide with this method's own local variables. The seeded
        # names were visible to expressions before and are kept for
        # backward compatibility.
        env = {
            'self': self,
            'inputs': self.op.inputs,
            'result_op': result_op,
            'sequence': sequence,
        }
        # NOTE(review): presumably the expression helper ``_`` (imported from
        # JsonParse) reads the current op from ``_.op`` — confirm there.
        _.op = self.op
        failed = False  # True once an expression has raised
        entries = list(self.operation['Var']) + list(self.operation['Feature'])
        for key_expression in entries:
            name, sep, expression = key_expression.partition('=')
            if not sep:
                raise ValueError(
                    "Malformed config entry for op type %r (expected "
                    "'name = expression'): %r" % (self.op.type, key_expression))
            key = name.strip()
            expression = expression.strip()
            if failed:
                # After one failure the rest are not evaluated; all are 'Error'.
                env[key] = 'Error'
            else:
                try:
                    # NOTE: eval executes arbitrary code from the JSON config;
                    # only load trusted configuration files.
                    env[key] = eval(expression, globals(), env)
                except Exception as e:
                    logger.warning("Failed to evaluate %r for op type %s: %s",
                                   key_expression, self.op.type, e)
                    failed = True
                    env[key] = 'Error'
            if recorded and key in self.feature_dict:
                value = env[key]
                result_op.setdefault(key, 0)
                # 'Error' markers (strings) create the entry but add nothing.
                if not isinstance(value, str):
                    result_op[key] += int(value)
        if recorded:
            features = {'op_type': self.op.type}
            for key in self.feature_dict:
                features[key] = env.get(key) or 0
            sequence.append(features)
